Muon Optimizer + FSDP

导言

muon 优化器在FSDP场景下 ,xtuner以及业界先进方法是如何实现的。

xtuner的实现是根据相同shape串性把tensor来all2all,专家tensor不够fsdp_size时,还需要padding。并且内存快照时发现all2all要申请一个大buffer 35B 128卡 sp4 256k,好像有10GB左右。

这份文档围绕两个问题展开:

  1. FSDP + Muon 的 All2All 在内存快照中申请了约 10GB 大 buffer,该如何避免?是否可以分块通信?
  2. veScale-FSDP 中关于 RaggedShard + Muon 的设计有什么值得学习的地方?

先给结论:

  • 10GB All2All buffer 大概率不是 Muon 必然开销,而是 bucket 粒度过大、padding、临时 pack/unpack buffer、Newton-Schulz workspace 叠加导致的峰值。
  • 可以分块通信,但应优先“按矩阵组分块”,不要把单个 Muon 矩阵切碎后分别正交化,否则会改变优化器语义。
  • exact Muon 的下界是:某个 rank 至少要持有一个完整 2D 矩阵及其 Newton-Schulz 临时 workspace。这个下界无法通过普通 All2All 消除。
  • veScale-FSDP 的关键思想不是简单换一个 collective,而是把“结构感知布局”做进 FSDP:RaggedShard 表示不规则 sharding,Planner 减少 padding,DBuffer 做持久化零拷贝通信 buffer。

背景:FSDP 与 Muon 的冲突

Muon 的核心更新不是逐元素操作,而是对 2D 矩阵的 momentum 做近似正交化:

1
2
3
M_t = momentum_update(grad)
U_t = NewtonSchulz(M_t)
W_t = W_t - lr * U_t

AdamW 可以在 FSDP local shard 上直接做,因为它是 element-wise update。
但 Muon 不行,因为:

1
2
3
Orthogonalize([M_0; M_1; ...; M_{S-1}])_r
!=
Orthogonalize(M_r)

也就是说,在每个 FSDP rank 的 shard 上独立跑 Newton-Schulz,不等价于对完整矩阵跑 Muon

因此 exact FSDP + Muon 通常需要:

  1. 每个 rank 持有矩阵的一个 shard;
  2. 通过 All2All / gather 把完整矩阵重组到某些 rank;
  3. 这些 rank 执行 Newton-Schulz;
  4. 再把更新后的 shard 发回原 FSDP layout。

XTuner 那种“同 shape tensor 分桶,然后 All2All;数量不够 FSDP size 时 padding”的方案,本质上是在实现这个 exact 语义。


问题一:为什么 All2All buffer 会到 10GB

假设:

1
2
3
4
FSDP size = S
同 shape 矩阵数量 = B
单个完整矩阵大小 = P bytes
每个 rank 最终负责 C 个完整矩阵

如果实现一次性处理一个大 bucket,那么每个 rank 至少可能需要:

1
2
3
4
5
input pack buffer      ≈ C * P
output full buffer ≈ C * P
Newton-Schulz workspace ≈ alpha * C * P
reverse All2All buffer ≈ C * P
padding/alignment ≈ extra

所以峰值近似是:

1
peak ≈ (2 + alpha + extra) * C * P

其中 alpha 取决于 Newton-Schulz 实现,可能是 1 到 3 甚至更高。

如果一个 35B 模型里某些 MLP 矩阵本身就有数百 MB,而一个 bucket 让每个 rank 同时接收多个完整矩阵,那么 10GB 峰值并不意外。

还要注意:SP4、256k sequence length 本身通常不直接决定 Muon All2All buffer 大小。Muon buffer 主要由模型参数形状、FSDP group、bucket 策略、dtype 决定。但 256k 长序列会让 activation、grad bucket、allocator reserved memory 压力更大,导致 optimizer step 的临时 buffer 更容易成为 OOM 触发点。


分块通信:可以,但要按矩阵分块

推荐分块方式

应该把一个大 Muon bucket 拆成多个 micro-bucket:

1
2
3
4
5
6
7
8
9
10
big bucket:
[W0, W1, W2, ..., W127]

micro-bucket 0:
[W0, W1, ..., W15]

micro-bucket 1:
[W16, W17, ..., W31]

...

每个 micro-bucket 执行:

1
2
3
4
5
6
1. pack local momentum shards
2. All2All: shard layout -> full matrix layout
3. Newton-Schulz on full matrices
4. All2All: full update layout -> original shard layout
5. apply update
6. reuse buffer for next micro-bucket

关键是:每个 Muon 单元仍然是完整矩阵

不推荐的分块方式

不要把单个矩阵切成多个 row block 分别做 Muon:

1
2
3
W = [W_top; W_bottom]

NewtonSchulz(W_top) 和 NewtonSchulz(W_bottom)

这不再是原始 Muon,而是 block-wise Muon / approximate Muon。它可能可用,但优化器语义已经变了,loss 曲线和稳定性都需要重新验证。

例外情况是 fused tensor:

1
2
qkv_proj.weight = [Q; K; V]
experts.weight = [E, out, in]

这类 tensor 可以先按语义拆成多个逻辑矩阵,再分别做 Muon。这个拆分是合理的,因为 Muon 单元本来就应该是每个逻辑矩阵,而不是整个 fused 大 tensor。


内存控制方案

1. 给 Muon 设置独立 bucket cap

不要沿用 FSDP 的通信 bucket 大小,也不要把所有 same-shape tensor 一次性 All2All。

建议引入:

1
2
3
4
5
muon:
comm_bucket_cap_mb: 512 # 起步建议 512MB 或 1GB
max_inflight_buckets: 1 # 内存紧张时先设 1
preallocate_workspace: true
dtype: bf16

估算 micro-bucket 大小:

1
2
3
4
5
6
def estimate_muon_peak(full_bytes, ns_alpha=2.0, reverse_buffer=True):
# full_bytes: 当前 micro-bucket 中每个 rank 需要持有的完整矩阵总 bytes
base = 2 * full_bytes # input + output
ns = ns_alpha * full_bytes
reverse = full_bytes if reverse_buffer else 0
return base + ns + reverse

选择 micro-bucket 时应满足:

1
estimate_muon_peak(bucket) <= muon_workspace_cap

如果当前看到 10GB buffer,可以先把 cap 降到:

1
512MB ~ 1GB full-matrix bytes per rank

然后观察 optimizer step time 的变化。通常这会增加 All2All 次数,但能显著降低峰值显存。


2. 预分配并复用 workspace

不要每个 bucket 都:

1
tmp = torch.empty(...)

而应该在 optimizer 初始化时预分配固定 workspace:

1
2
3
4
5
workspace = MuonWorkspace(
in_buf=torch.empty(max_bytes, dtype=torch.bfloat16, device="cuda"),
out_buf=torch.empty(max_bytes, dtype=torch.bfloat16, device="cuda"),
ns_buf=torch.empty(max_ns_bytes, dtype=torch.bfloat16, device="cuda"),
)

每个 micro-bucket 只使用 narrow/view

1
2
in_view = workspace.in_buf[:needed_numel]
out_view = workspace.out_buf[:needed_numel]

这样可以避免:

  • PyTorch caching allocator 反复申请大块内存;
  • stream lifetime 导致旧 buffer 不能及时复用;
  • memory fragmentation;
  • snapshot 中出现多个临时大块并存。

veScale-FSDP 的 DBuffer 思想也类似:用持久化 distributed buffer 和地址映射来避免反复 copy/alloc。


3. 限制 in-flight bucket 数量

为了 overlap,很多实现会同时挂多个异步 All2All:

1
2
3
chunk 0 communicating
chunk 1 packing
chunk 2 Newton-Schulz

这有利于性能,但会增加峰值显存。

在 35B、128 卡、SP4、256k 这种 activation 压力很高的场景,建议先关闭 aggressive overlap:

1
2
muon:
max_inflight_buckets: 1

稳定后再尝试:

1
2
muon:
max_inflight_buckets: 2

不要一开始就让 3 到 4 个 Muon chunks 同时在飞。


4. 尽量使用 BF16 通信和计算 buffer

检查 All2All buffer 的 dtype。

如果 momentum 或 update buffer 是 FP32:

1
10GB FP32 -> 5GB BF16

Muon 的大规模实现通常会尽量让通信和 Newton-Schulz 主路径使用 BF16 / FP16 / Tensor Core 友好的格式。
但需要注意:这可能影响数值稳定性。建议至少记录:

1
2
3
4
update RMS
grad norm
attention logits max
loss spike

如果 BF16 Muon 不稳定,可以只对 Newton-Schulz 的某些归一化标量保留 FP32,而不是让整个 All2All buffer 变成 FP32。


5. 用 all_to_all_single 的 split sizes 或 ragged all-to-all 减少 padding

XTuner 的 same-shape bucket 通常要求:

1
real_count padded 到 fsdp_size 的倍数

如果 expert tensor 数量少于 FSDP size,会产生严重 padding:

1
2
3
experts = 8
fsdp_size = 64
padding ratio = 87.5%

可以改成两条路径:

1
2
3
4
5
dense/common shape:
equal-size all_to_all_single,走高带宽路径

expert/small/irregular shape:
ragged all_to_all / all_to_all_single(split_sizes) / gather-to-root

如果后端支持 variable split sizes,优先避免 dummy tensor padding。
如果 variable all-to-all 性能不理想,小 bucket 可以退回 P2P gather/scatter 或 AdamW。


6. 跨 layer 合并 expert bucket

不要按 layer 做 expert bucket:

1
2
3
layer0: 8 experts -> pad to 64
layer1: 8 experts -> pad to 64
...

应该跨 layer 合并:

1
2
3
all_layers.experts.up_proj
all_layers.experts.gate_proj
all_layers.experts.down_proj

例如:

1
2
3
4
5
6
7
8
9
10
num_layers = 32
experts_per_layer = 8
fsdp_size = 64

按层 bucket:
每层 8 个,padding 87.5%

跨层 bucket:
32 * 8 = 256 个
可切成 4 个 64-bucket,几乎无 padding

这是 MoE + FSDP + Muon 中非常关键的优化。


7. 对不划算的参数退回 AdamW

你前面已经观察到:在 Qwen 35B SFT 中,Muon 每步 loss 下降未必比 AdamW 快。

因此在 SFT 场景没必要追求 Muon 覆盖率 100%。建议:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
Muon:
attention q/k/v/o projection
dense MLP gate/up/down
大多数 routed expert MLP

AdamW:
embedding
lm_head
norm
bias
router/gate
LoRA 参数
padding ratio 过高的小 expert bucket
极大且导致显存峰值的个别矩阵

可以加策略:

1
2
3
4
5
6
7
8
if bucket.padding_ratio > 0.5:
use_adamw(bucket)

if estimated_muon_peak(bucket) > hard_cap:
split_bucket_or_use_adamw(bucket)

if sft and eval_not_improved_by_muon:
reduce_muon_coverage()

8. HSDP 降低 FSDP group size

veScale-FSDP 论文也给出类似经验:不要盲目扩大 FSDP group size,必要时用 HSDP 控制 collective group。

例如总共 128 卡:

1
2
3
4
5
6
7
8
9
10
11
方案 A:
fsdp_size = 128
dp_replicas = 1

方案 B:
fsdp_size = 64
dp_replicas = 2

方案 C:
fsdp_size = 32
dp_replicas = 4

较小的 fsdp_size 通常可以:

  • 减少 expert padding;
  • 降低 collective group 复杂度;
  • 改善 NCCL latency 和 LCM rounding;
  • 让 bucket 更容易规划。

代价是:

  • 每卡参数 shard、grad shard、optimizer state 变大;
  • DP replica 之间还需要同步梯度或 optimizer state;
  • 总显存不一定下降,需要实测。

对于 35B SFT,如果 activation 才是主要压力,HSDP 未必能直接省显存;但如果 All2All padding 和大 bucket 是主要问题,HSDP 很值得试。


一个推荐的分块执行伪代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class MuonFSDPExecutor:
def __init__(self, fsdp_group, workspace_cap_bytes, max_inflight=1):
self.fsdp_group = fsdp_group
self.workspace = preallocate_workspace(workspace_cap_bytes)
self.max_inflight = max_inflight

def step_bucket(self, bucket):
chunks = plan_micro_buckets(
bucket,
cap_bytes=self.workspace.cap_bytes,
cost_fn=estimate_muon_peak,
)

for chunk in chunks:
# 1. 本地 momentum update
local_m_shards = update_momentum_local(chunk)

# 2. pack 到持久化 input buffer
in_view = self.workspace.pack(local_m_shards)

# 3. shard layout -> full matrix layout
out_view = self.workspace.alloc_output(chunk)
dist.all_to_all_single(
output=out_view,
input=in_view,
group=self.fsdp_group,
)

# 4. 每个 rank 对自己负责的完整矩阵跑 Newton-Schulz
full_mats = unpack_full_matrices(out_view, chunk)
updates = []
for mat in full_mats:
updates.append(newton_schulz(mat))

# 5. pack update,反向 All2All 回原 FSDP shard layout
update_view = self.workspace.pack(updates)
shard_update_view = self.workspace.alloc_shard_output(chunk)
dist.all_to_all_single(
output=shard_update_view,
input=update_view,
group=self.fsdp_group,
)

# 6. apply local shard update
apply_update_local(chunk.params, shard_update_view)

# 7. workspace 逻辑释放,下一 chunk 复用
self.workspace.reset()

重点:

1
2
分块对象:多个完整矩阵组成的 micro-bucket
不要分块对象:单个矩阵的 row shard

veScale-FSDP 的核心思想

veScale-FSDP 论文认为,传统 FSDP 的 element-wise 或 row-wise fixed sharding 难以支持结构感知训练,例如 Muon、Shampoo、block-wise quantization。它提出三个关键组件:

1
2
3
RaggedShard
Structure-aware Planner
DBuffer

RaggedShard

RaggedShard 是一种 DTensor placement,用来表达不规则 sharding。

传统 Shard(0) 通常要求均匀切分:

1
2
3
4
rank0: same size
rank1: same size
rank2: same size
...

RaggedShard 允许:

1
2
3
4
rank0: 0 block
rank1: 0 block
rank2: full tensor
rank3: 0 block

也允许:

1
2
3
4
rank0: 1 unit
rank1: 2 units
rank2: 1 unit
rank3: 1 unit

这对于 Muon 很有用,因为可以把某个完整矩阵重分布到一个 root rank:

1
2
3
4
5
6
original FSDP placement:
each rank owns a shard

RaggedShard(root):
only root owns the full 2D matrix
other ranks own empty tensor

veScale 文档中也明确提到,Muon 的 Newton-Schulz 需要完整 2D 参数矩阵,RaggedShard 可以通过 DTensor.redistribute 表达 gather -> compute -> scatter 这个过程。


Structure-aware Planner

如果只是把 RaggedShard tensor 简单拼起来,可能出现:

1
2
3
4
block 被切碎
tensor 内部插 padding
每个 rank buffer 不均衡
通信 buffer 非连续

veScale 的 planner 目标是:

1
2
3
4
1. 不切碎结构块
2. 保持 tensor contiguous
3. 平衡每个设备的通信负载
4. 尽量把 padding 放在 tensor 之间,而不是 tensor 内部

这点对 Muon 和 MoE 都很重要。

对于你看到的 expert tensor padding,veScale 的思路不是简单“补 dummy tensor 到 fsdp_size”,而是做全局 layout planning,尽量减少 padding 和 LCM rounding。

论文中还给出经验:不要使用过大的 FSDP group size,可以通过 HSDP 控制 shard group,并通过离线模拟选择 padding 最小的 FSDP size。


DBuffer

DBuffer 是 veScale-FSDP 的通信 buffer 抽象。

它的目标是:

1
2
3
4
5
1. 持久化分配通信 buffer
2. 多 tensor group-level 操作
3. zero-copy access
4. in-place communication/computation
5. 降低 PyTorch allocator fragmentation

这正好对应你看到的 10GB 临时 All2All buffer 问题。

如果没有 DBuffer,一个朴素实现通常会反复:

1
2
3
4
5
6
torch.empty(...)
torch.cat(...)
torch.stack(...)
contiguous()
all_to_all(...)
unpack(...)

这会在 memory snapshot 中出现大量临时大块。

借鉴 DBuffer 后,应把 Muon 的通信区改成:

1
2
3
4
初始化时规划地址
初始化时分配最大 workspace
每个 step 使用 view/narrow
不在热路径中频繁申请大 tensor

veScale 的 Distributed Muon 流程

veScale-FSDP 论文中的 Muon 逻辑可以概括为:

1
2
3
4
5
6
7
8
9
10
11
12
for each 2D parameter w:
g = grad(w)
u = MomentumUpdate(g, m)
p = original placement(u)

r = SelectRoot() # 负载均衡选择 root
o = Redistribute(u, RaggedShard(r)) # root 持有完整矩阵

o = NewtonSchulz(o) # 只有 root 真正计算

o = Redistribute(o, p) # 回到原 FSDP shard
w = w - lr * o

这个设计和 XTuner same-shape All2All 的目标类似,都是 exact Muon。
但抽象层次不同:

方案 核心思路 优点 风险
XTuner-style same-shape All2All 同 shape tensor 批量重排 简单,高带宽,容易实现 padding、大 bucket、大 buffer
veScale RaggedShard 用 placement 表达不规则 gather/scatter 语义清晰,减少 padding,适合结构感知 optimizer 需要 DTensor/RaggedShard/Planner 支撑
DBuffer 持久化通信 buffer 和地址映射 降低 allocator 峰值和 copy 工程复杂度更高
Rooted gather 每个矩阵选 root rank 避免同 shape 数量不足 padding 需要负载均衡和异步 overlap

推荐架构

如果你要从当前 XTuner-style FSDP + Muon 演进,我建议分三层做。

第一层:保留 exact All2All,但加 micro-bucket

这是最容易落地的改造。

1
2
3
4
5
6
7
8
目标:
把 10GB 临时 buffer 降到 1GB ~ 2GB 可控范围

做法:
1. same-shape bucket 保留
2. bucket 内按 workspace_cap 切 micro-bucket
3. max_inflight 先设为 1
4. 所有临时 buffer 预分配并复用

建议配置:

1
2
3
4
5
6
7
muon:
exact: true
comm_dtype: bf16
workspace_cap_mb: 1024
max_inflight_buckets: 1
preallocate_workspace: true
fallback_min_fill_ratio: 0.5

第二层:MoE expert 改成全局规划

针对 expert tensor:

1
2
3
1. 跨 layer 合并 same-shape experts
2. padding ratio > 50% 的 bucket 退回 AdamW 或走 ragged path
3. 如果 expert 是 [E, out, in] fused layout,并且 shard 在 E 维,则本地 per-expert Muon,不需要 All2All

判断逻辑:

1
2
3
4
5
6
7
if is_expert_batch_sharded(param):
# local shard already owns complete expert matrices
run_local_per_expert_muon(param)
elif bucket.padding_ratio <= 0.5:
run_all2all_muon(bucket)
else:
run_adamw(bucket)

第三层:引入 RaggedShard / Rooted Muon

当 same-shape + padding 已经成为主要瓶颈时,再引入 veScale 风格设计:

1
2
3
4
5
6
1. 每个 Muon 矩阵选择一个 root rank
2. 使用 ragged placement 表示 root 持有完整矩阵
3. redistribute 到 root
4. root 上 Newton-Schulz
5. redistribute 回原 placement
6. 使用 planner 控制 root 负载和 buffer cap

root 选择可以按 estimated cost 做负载均衡:

1
2
3
4
5
def select_root(matrix, rank_load):
cost = matrix.numel() * matrix.dtype.itemsize
root = min(rank_load, key=rank_load.get)
rank_load[root] += cost
return root

不要简单 round-robin,因为不同矩阵大小差异很大。


针对 35B / 128 卡 / SP4 / 256k 的建议

优先级如下:

  1. 先确认 All2All buffer dtype
    如果是 FP32,优先改成 BF16。

  2. 把 Muon bucket cap 降到 512MB 或 1GB
    先牺牲一点 optimizer step time,换取显存稳定。

  3. 关闭多 chunk overlap
    max_inflight_buckets=1,稳定后再开到 2。

  4. 预分配 Muon workspace
    不要在每个 step、每个 bucket 动态 torch.empty 大 tensor。

  5. 跨 layer 合并 expert bucket
    避免每层 expert 数量小于 FSDP size 导致巨量 padding。

  6. padding ratio 高的 expert bucket 退回 AdamW
    SFT 中 Muon 未必带来收益,不值得为这些参数付出 10GB buffer。

  7. 评估 HSDP
    比如:

    1
    2
    fsdp_size = 64, dp = 2
    fsdp_size = 32, dp = 4

    用离线脚本估算 padding 和 per-rank state,再实测。

  8. 检查 optimizer step 前 activation 是否真正释放
    256k 下 activation 压力极大。可以诊断性地在 optimizer 前插入同步,确认是否是 stream lifetime 导致临时 buffer 共存:

    1
    2
    3
    del loss, outputs
    torch.cuda.synchronize()
    optimizer.step()

    这不是最终性能方案,但可以帮助定位峰值来源。


需要记录的指标

建议每个 Muon bucket 打印:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
bucket_name
logical_shape
num_real_tensors
num_padded_tensors
padding_ratio
full_bytes_per_rank
estimated_comm_buffer_bytes
estimated_ns_workspace_bytes
actual_allocated_before
actual_allocated_after
actual_peak_allocated
all2all_in_time_ms
newton_schulz_time_ms
all2all_out_time_ms

示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
logger.info(
"[muon_bucket] key=%s real=%d padded=%d pad=%.2f "
"full_rank=%.2fGB comm=%.2fGB ns=%.2fGB "
"t_a2a_in=%.2fms t_ns=%.2fms t_a2a_out=%.2fms",
bucket.key,
bucket.real_count,
bucket.padded_count,
bucket.padding_ratio,
full_bytes_per_rank / 2**30,
comm_bytes / 2**30,
ns_bytes / 2**30,
t_in,
t_ns,
t_out,
)

没有这些指标,很难判断 10GB 是来自:

1
2
3
4
5
6
7
bucket 太大
padding 太多
dtype 不对
NS workspace 太大
多 chunk overlap
allocator fragmentation
activation 未释放

结论

对于你的场景,最实际的路线是:

1
2
3
4
5
6
7
8
9
10
11
12
13
短期:
XTuner-style same-shape All2All 保留
加 micro-bucket + workspace cap + BF16 + max_inflight=1
padding 高的 expert bucket 退回 AdamW

中期:
跨 layer expert bucket
expert-batch-sharded fast path
HSDP 调整 fsdp_size

长期:
学 veScale-FSDP
引入 RaggedShard / rooted redistribution / planner / persistent DBuffer

一句话总结:

FSDP + Muon 的核心不是“能不能 All2All”,而是“能否在保持完整矩阵 Muon 语义的同时,把重分布、padding、workspace 和 allocator 生命周期都纳入统一规划”。XTuner 的实现解决了 correctness,veScale-FSDP 的设计进一步解决了 layout、padding 和 buffer 生命周期问题。


参考资料

  • veScale-FSDP paper: veScale-FSDP: Flexible and High-Performance FSDP at Scale 1
  • veScale RaggedShard 文档: RaggedShard Placement 2
  • veScale GitHub: volcengine/veScale 3
  • PyTorch / TorchTitan FSDP2 notes: torchtitan FSDP documentation 4
  • Microsoft Dion / Muon distributed implementation: microsoft/dion 5
Author

Shaojie Tan

Posted on

2026-05-07

Updated on

2026-07-03

Licensed under